{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: switch to AMI\n",
    "PROTOCOL = 'Debug.SpeakerDiarization.Debug'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: update this tutorial to do fine tuning of a model pretrained on DIHARD"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Overlapped speech detection with `pyannote.audio`\n",
    "\n",
    "Overlapped speech detection (OSD) is the task of detecting regions where at least two speakers are speaking at the same time. In this notebook, we will train and evaluate an OSD pipeline on the Debug database."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyannote.database import get_protocol, FileFinder\n",
    "protocol = get_protocol(PROTOCOL, preprocessors={\"audio\": FileFinder()})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`pyannote.database` *protocols* usually define:\n",
    "* a training set: `for training_file in protocol.train(): ...`\n",
    "* a validation set: `for validation_file in protocol.development(): ...`\n",
    "* an evaluation set: `for evaluation_file in protocol.test(): ...`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's listen to the first training file and visualize its reference annotation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "first_training_file = next(protocol.train())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyannote.audio.utils.preview import listen\n",
    "listen(first_training_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "first_training_file['annotation']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The expected output of a perfect overlapped speech detection pipeline would look like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyannote.audio.pipelines.overlapped_speech_detection import OracleOverlappedSpeechDetection\n",
    "oracle_osd = OracleOverlappedSpeechDetection()\n",
    "\n",
    "oracle_osd(first_training_file).get_timeline()"
   ]
  },
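  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the notion of overlap concrete, here is a minimal pure-Python sketch that derives overlapped regions from a list of `(start, end)` speaker turns by sweeping over turn boundaries. It is not the `OracleOverlappedSpeechDetection` implementation: the `overlap_regions` helper and the toy turns are hypothetical and only illustrate the definition of overlapped speech given at the top of this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def overlap_regions(turns):\n",
    "    # hypothetical helper: regions where at least two turns are active\n",
    "    # +1 when a turn starts, -1 when it ends (ends sort before starts at equal times)\n",
    "    events = sorted([(s, 1) for s, e in turns] + [(e, -1) for s, e in turns])\n",
    "    regions, active, region_start = [], 0, None\n",
    "    for time, delta in events:\n",
    "        previous, active = active, active + delta\n",
    "        if previous < 2 <= active:\n",
    "            region_start = time  # a second speaker just joined\n",
    "        elif active < 2 <= previous and time > region_start:\n",
    "            regions.append((region_start, time))  # back to at most one speaker\n",
    "    return regions\n",
    "\n",
    "# hypothetical speaker turns, in seconds\n",
    "overlap_regions([(0.0, 5.0), (3.0, 8.0), (7.0, 9.0)])  # -> [(3.0, 5.0), (7.0, 8.0)]"
   ]
  },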
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We initialize an OSD *task* that describes how the model will be trained:\n",
    "\n",
    "* `protocol` indicates that we will use files available in `protocol.train()`.\n",
    "* `duration=2.` and `batch_size=16` indicate that the model will ingest batches of 16 two-second audio chunks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyannote.audio.tasks import OverlappedSpeechDetection\n",
    "osd = OverlappedSpeechDetection(protocol, duration=2., batch_size=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We initialize the *model*: it needs to know about the task (`task=osd`) for which it is being trained:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel\n",
    "model = SimpleSegmentationModel(task=osd)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that everything is ready, let's train with `pytorch-lightning`!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl\n",
    "trainer = pl.Trainer(max_epochs=10)\n",
    "trainer.fit(model, osd)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference\n",
    "\n",
    "Once the model is trained, we can apply it to a test file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_file = next(protocol.test())\n",
    "# here we use a test file provided by the protocol, but it could be any audio file\n",
    "# e.g. test_file = \"/path/to/test.wav\"."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the model was trained on 2s audio chunks and test files are likely to be much longer than that, we wrap the `model` in an `Inference` instance: it takes care of sliding a 2s window over the whole file and aggregating the outputs of the model."
   ]
  },
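  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sliding-window aggregation can be pictured with the following NumPy sketch. It is not the `Inference` implementation: `sliding_average`, the toy `chunk_model`, and the window and step sizes (expressed in frames) are hypothetical, and overlapping chunk outputs are simply averaged, which is an assumption made for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def sliding_average(signal, chunk_model, window, step):\n",
    "    # accumulate per-frame scores and how many chunks covered each frame\n",
    "    total = np.zeros(len(signal))\n",
    "    count = np.zeros(len(signal))\n",
    "    for start in range(0, len(signal) - window + 1, step):\n",
    "        chunk = signal[start:start + window]\n",
    "        total[start:start + window] += chunk_model(chunk)\n",
    "        count[start:start + window] += 1\n",
    "    # average overlapping outputs; frames never covered by a chunk stay at 0\n",
    "    return total / np.maximum(count, 1)\n",
    "\n",
    "# toy signal and toy chunk model that returns one score per frame\n",
    "toy_signal = np.linspace(0.0, 1.0, 10)\n",
    "sliding_average(toy_signal, lambda chunk: chunk ** 2, window=4, step=2)"
   ]
  },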
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyannote.audio import Inference\n",
    "inference = Inference(model)\n",
    "osd_probability = inference(test_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "osd_probability"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Almost there! To obtain the final overlapped speech regions, we need to apply a detection threshold.  \n",
    "For that, we rely on the overlapped speech detection pipeline whose hyper-parameters are set manually:\n",
    "- `onset=0.5`: mark region as `active` when probability goes above 0.5\n",
    "- `offset=0.5`: switch back to `inactive` when probability goes below 0.5\n",
    "- `min_duration_on=0.1`: remove `active` regions shorter than 100ms\n",
    "- `min_duration_off=0.1`: fill `inactive` regions shorter than 100ms."
   ]
  },
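  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The role of these four hyper-parameters can be illustrated with a small pure-Python sketch. It is not the pipeline's actual code: `binarize` and `apply_min_durations` are hypothetical helpers, the scores are made up, and filling short gaps before removing short regions is an assumption about the order of operations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def binarize(scores, frame_duration, onset=0.5, offset=0.5):\n",
    "    # hysteresis thresholding: switch on above `onset`, switch off below `offset`\n",
    "    regions, active, start = [], False, 0.0\n",
    "    for i, score in enumerate(scores):\n",
    "        t = i * frame_duration\n",
    "        if not active and score > onset:\n",
    "            active, start = True, t\n",
    "        elif active and score < offset:\n",
    "            active = False\n",
    "            regions.append((start, t))\n",
    "    if active:  # close a region that is still active at the end of the file\n",
    "        regions.append((start, len(scores) * frame_duration))\n",
    "    return regions\n",
    "\n",
    "def apply_min_durations(regions, min_duration_on=0.1, min_duration_off=0.1):\n",
    "    # fill short inactive gaps first, then drop short active regions\n",
    "    merged = []\n",
    "    for start, end in regions:\n",
    "        if merged and start - merged[-1][1] < min_duration_off:\n",
    "            merged[-1] = (merged[-1][0], end)\n",
    "        else:\n",
    "            merged.append((start, end))\n",
    "    return [(s, e) for s, e in merged if e - s >= min_duration_on]\n",
    "\n",
    "# hypothetical frame-level scores, one every 50ms\n",
    "toy_scores = [0.1, 0.7, 0.8, 0.4, 0.9, 0.9, 0.9, 0.2]\n",
    "apply_min_durations(binarize(toy_scores, frame_duration=0.05))"
   ]
  },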
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyannote.audio.pipelines import OverlappedSpeechDetection as OverlappedSpeechDetectionPipeline\n",
    "pipeline = OverlappedSpeechDetectionPipeline(scores=inference).instantiate(\n",
    "    {\"onset\": 0.5, \"offset\": 0.5, \"min_duration_on\": 0.1, \"min_duration_off\": 0.1})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we go:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline(test_file).get_timeline()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimizing pipeline hyper-parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can try to optimize the hyper-parameters (that we chose manually above) on the validation set to get better performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# to make things faster, we run the inference once and for all... \n",
    "validation_files = list(protocol.development())\n",
    "for file in validation_files:\n",
    "    file['osd'] = inference(file)\n",
    "# ... and tell the pipeline to load OSD scores directly from files\n",
    "pipeline = OverlappedSpeechDetectionPipeline(scores=\"osd\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyannote.pipeline import Optimizer\n",
    "optimizer = Optimizer(pipeline)\n",
    "optimizer.tune(validation_files, n_iterations=200, show_progress=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There you go: better hyper-parameters that should lead to better results!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimized_pipeline = OverlappedSpeechDetectionPipeline(scores=inference).instantiate(optimizer.best_params)\n",
    "optimized_pipeline(test_file).get_timeline()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
